package org.zjvis.datascience.common.util;

import cn.hutool.core.lang.Pair;
import com.alibaba.fastjson.JSONArray;
import com.google.common.base.Joiner;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.zjvis.datascience.common.constant.SqlTemplate;
import org.zjvis.datascience.common.sql.SqlHelper;

import java.io.*;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * @description Task utility class: determines the type of a node's result and provides data-manipulation helpers
 * @date 2021-10-29
 */
public class ToolUtil {

    /** SLF4J logger registered under the literal name "ToolUtil". */
    private final static Logger logger = LoggerFactory.getLogger("ToolUtil");

    /**
     * Regex with three groups: everything before the definition, the text
     * {@code def <whitespace> begin_alg( ... )} itself, and everything after it.
     */
    public final static String FUNC_NAME_PAT = "(.*?)(def\\s{1,}begin_alg\\(.*?\\))(.*)";
    /**
     * {@link #FUNC_NAME_PAT} compiled with MULTILINE and DOTALL, so {@code .} also matches
     * line terminators. NOTE(review): this field is public and not final, so other code can
     * reassign it.
     */
    public static Pattern pattern = Pattern
            .compile(FUNC_NAME_PAT, Pattern.MULTILINE | Pattern.DOTALL);

    /** Column name that {@code filterTypeAndCol} leaves out of most of its results. */
    public static final String ID_COL = "_record_id_";

    /** Lower-case names of character-like column types; filled in the static initializer. */
    private static Set<String> charTypes;

    /** Lower-case names of date/time column types; filled in the static initializer. */
    private static Set<String> dateTypes;

    /** Lower-case names of numeric (integer and floating) column types; filled in the static initializer. */
    private static Set<String> numericTypes;

    /** Lower-case names of integer column types; filled in the static initializer. */
    private static Set<String> intTypes;

    /** Lower-case names of floating-point column types; filled in the static initializer. */
    private static Set<String> floatTypes;

    /** Category keyword selecting the character type set. */
    public static final String STRING_TYPE = "string";

    /** Category keyword selecting the numeric type set. */
    public static final String NUMBER_TYPE = "number";

    /** Category keyword selecting the integer type set. */
    public static final String INT_TYPE = "int";

    /** Category keyword selecting the floating-point type set. */
    public static final String FLOAT_TYPE = "float";

    /** Category keyword selecting the date/time type set. */
    public static final String DATE_TYPE = "date";

    // Earlier variant with surrounding capture groups, kept for reference:
    // public static final String PY_FUNC_PATTERN = "(.*)(def\\s{1,}.*?\\(.*?\\):)(.*)";
    /**
     * Regex matching {@code def}, at least one whitespace character, then lazily anything up to
     * a parenthesised section followed by a colon. No flags are attached here; the pattern is
     * compiled by its users.
     */
    public static final String PY_FUNC_PATTERN = "def\\s{1,}.*?\\(.*?\\):";

    // Populates the type-name lookup sets once, when the class is initialized.
    // All names are lower case; callers are expected to lower-case before looking up.
    static {
        charTypes = new HashSet<>(Arrays.asList(
                "char", "varchar", "longvarchar", "longnvarchar", "text", "character",
                "character varying", "nchar", "nvarchar", "clob", "nclob", "blob"));

        // NOTE(review): "timestamp_with_timezone" uses underscores while
        // "timestamp without time zone" uses spaces - confirm both spellings are intended.
        dateTypes = new HashSet<>(Arrays.asList(
                "date", "timestamp", "time", "timestamp_with_timezone",
                "timestamp without time zone"));

        numericTypes = new HashSet<>(Arrays.asList(
                "integer", "numeric", "smallint", "int", "bigint", "double precision",
                "int2", "int4", "int8", "float8", "float4", "decimal", "float", "real"));

        intTypes = new HashSet<>(Arrays.asList(
                "int", "smallint", "integer", "int2", "int4", "bigint", "int8"));

        // NOTE(review): "float2" is listed while "float", "real", "numeric" and "decimal"
        // (all present in numericTypes) are not - confirm whether this is deliberate.
        floatTypes = new HashSet<>(Arrays.asList(
                "double precision", "float2", "float8", "float4"));
    }

    /**
     * Returns a name that does not collide with the entries of {@code nameList}, without
     * modifying the list.
     *
     * @param name     the desired name
     * @param nameList names already in use
     * @return {@code name} itself when it is free, otherwise {@code name_N} with the next free N
     * @see #checkDuplicateName(String, List, boolean)
     */
    public static String checkDuplicateName(String name, List<String> nameList) {
        return checkDuplicateName(name, nameList, false);
    }

    /**
     * Returns a name that does not collide with the entries of {@code nameList}.
     *
     * <p>An entry equal to {@code name} counts as index 0 and an entry of the form
     * {@code name_<digits>} counts as that number. When neither form occurs, {@code name} is
     * returned unchanged; otherwise the result is {@code name_(max + 1)}.
     *
     * @param name       the desired name; {@code null} is returned as-is and never added to the list
     * @param nameList   names already in use; {@code null} elements are skipped
     * @param needUpdate when {@code true} the returned name is appended to {@code nameList}
     * @return a name not present in {@code nameList}
     * @throws NullPointerException if {@code nameList} is {@code null}
     */
    public static String checkDuplicateName(String name, List<String> nameList,
                                            boolean needUpdate) {
        if (name == null) {
            return null;
        }
        // Quote the name so regex metacharacters inside it ("(", "[", "+", ".") are matched
        // literally, and compile the pattern once instead of once per list element.
        Pattern numbered = Pattern.compile(Pattern.quote(name) + "_(\\d+)");
        long highest = -1L;
        for (String existing : nameList) {
            if (existing == null) {
                continue;
            }
            if (existing.equals(name)) {
                highest = Math.max(highest, 0L);
                continue;
            }
            Matcher m = numbered.matcher(existing);
            if (m.matches()) {
                try {
                    highest = Math.max(highest, Long.parseLong(m.group(1)));
                } catch (NumberFormatException ignored) {
                    // The numeric suffix does not fit in a long; such an entry cannot clash
                    // with the small suffix generated below, so it is safe to skip.
                }
            }
        }
        String result = highest < 0 ? name : name + "_" + (highest + 1);
        if (needUpdate) {
            nameList.add(result);
        }
        return result;
    }

    /**
     * Resolves the classpath resource {@code template/algo/dbscan.json} through this class's
     * class loader and returns the path component of its URL.
     *
     * @return the path of the resource URL
     * @throws NullPointerException if the resource cannot be found on the classpath
     */
    public String getFilePath() {
        ClassLoader loader = getClass().getClassLoader();
        return loader.getResource("template/algo/dbscan.json").getPath();
    }

    /**
     * Reads a classpath resource as UTF-8 text with all line terminators removed.
     *
     * @param fileName classpath-relative resource name
     * @return the concatenated lines, or an empty string when the resource is missing or unreadable
     */
    public String readJsonFile(String fileName) {
        return readRawFile(fileName, false);
    }

    /**
     * Reads a classpath resource as UTF-8 text, line by line.
     *
     * @param fileName    classpath-relative resource name
     * @param withLineSep when {@code true} a {@code "\n"} is appended after every line
     *                    (including the last one); when {@code false} lines are concatenated
     * @return the resource content, or an empty string when {@code fileName} is {@code null},
     *         the resource does not exist, or reading fails
     */
    public String readRawFile(String fileName, boolean withLineSep) {
        if (fileName == null) {
            logger.warn("readRawFile called with a null resource name");
            return "";
        }
        ClassLoader loader = this.getClass().getClassLoader();
        // try-with-resources tolerates a null resource, so the missing-resource case is
        // checked inside the block; closing the stream also releases the readers wrapping it.
        try (InputStream in = loader.getResourceAsStream(fileName)) {
            if (in == null) {
                logger.warn("classpath resource not found: {}", fileName);
                return "";
            }
            BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8"));
            StringBuilder content = new StringBuilder();
            String line;
            while ((line = reader.readLine()) != null) {
                content.append(line);
                if (withLineSep) {
                    content.append("\n");
                }
            }
            return content.toString();
        } catch (IOException e) {
            logger.error("failed to read classpath resource {}", fileName, e);
            return "";
        }
    }

    /**
     * Normalizes a table reference into the name that should be used in SQL.
     *
     * <ul>
     *   <li>A name that already starts with {@code SqlTemplate.SCHEMA} or
     *       {@code SqlTemplate.SOURCE_SCHEMA} and does not end with {@code "_"} is returned as-is.</li>
     *   <li>A name starting with {@code "graph"}, or whose first (up to) seven characters contain
     *       {@code "select"} ignoring case, is returned as-is.</li>
     *   <li>When the last dot-separated segment starts with {@code "view_"} or {@code "solid_"},
     *       a non-zero {@code timeStamp} is appended and the name is prefixed with
     *       {@code SqlTemplate.SCHEMA + "."} if it is not already.</li>
     *   <li>A name starting with {@code "ml_model"} is returned as-is.</li>
     *   <li>Anything else is prefixed with {@code SqlTemplate.SOURCE_SCHEMA + "."} if it is not already.</li>
     * </ul>
     *
     * @param tableName table name, possibly schema-qualified, or an inline query
     * @param timeStamp suffix appended to view/solid table names; {@code 0} means no suffix
     * @return the aligned table name
     * @throws NullPointerException if {@code tableName} is {@code null}
     */
    public static String alignTableName(String tableName, long timeStamp) {
        // Per the original note: after a cleaning node runs it updates its child nodes' SQL
        // and passes in a complete table name, which must not get a timestamp appended.
        boolean alreadyQualified = tableName.startsWith(SqlTemplate.SCHEMA)
                || tableName.startsWith(SqlTemplate.SOURCE_SCHEMA);
        if (alreadyQualified && !tableName.endsWith("_")) {
            return tableName;
        }
        // Only the leading characters are inspected. The end index is clamped so that names
        // shorter than seven characters no longer throw StringIndexOutOfBoundsException.
        String head = tableName.substring(0, Math.min(7, tableName.length()))
                .toLowerCase(Locale.ROOT);
        if (tableName.startsWith("graph") || head.contains("select")) {
            return tableName;
        }

        String[] names = tableName.split("\\.");
        // split() returns an empty array for a name made only of dots; treat that as an
        // empty last segment instead of indexing at -1.
        String lastSegment = names.length == 0 ? "" : names[names.length - 1];
        // A created view table, or an intermediate result table.
        if (lastSegment.startsWith("view_") || lastSegment.startsWith("solid_")) {
            logger.debug("view table or solid table");
            if (timeStamp != 0L) {
                tableName += timeStamp;
            }
            if (!tableName.startsWith(SqlTemplate.SCHEMA + ".")) {
                tableName = SqlTemplate.SCHEMA + "." + tableName;
            }
        } else {
            // When saving model results locally, return the output table name directly.
            if (tableName.startsWith("ml_model")) {
                return tableName;
            }
            // Original source data table.
            String schema = String.format("%s.", SqlTemplate.SOURCE_SCHEMA);
            if (!tableName.startsWith(schema)) {
                tableName = schema + tableName;
            }
        }
        return tableName;
    }

    /**
     * Manual smoke check for {@code checkDuplicateName}: prints the name proposed for a list
     * that already holds the base name and its first copy, then prints the list itself.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        // Instantiation retained from the original demo; the instance itself is not used.
        new ToolUtil();
        final String name = "100_复制";
        List<String> list = new ArrayList<>(Arrays.asList(name, "100_复制_1"));
        System.out.println(checkDuplicateName(name, list, false));
        System.out.println(list);
    }

    /**
     * Looks up the type that sits at the same position as the first occurrence of
     * {@code columnName} in {@code columnNames}.
     *
     * @param columnNames column names, positionally aligned with {@code columnTypes}
     * @param columnTypes column types
     * @param columnName  the column to look up (exact, case-sensitive match)
     * @return the matching type, or an empty string when any argument is {@code null} or the
     *         column is not present
     */
    public static String getSpecColumnType(List<String> columnNames, List<String> columnTypes,
                                           String columnName) {
        if (columnNames == null || columnTypes == null || columnName == null) {
            return "";
        }
        int position = columnNames.indexOf(columnName);
        return position >= 0 ? columnTypes.get(position) : "";
    }

    /**
     * Builds a map from upper-cased column name to a numeric type code: {@code 12} for
     * character and date/time columns, {@code 12345678} for everything else.
     *
     * <p>NOTE(review): 12 presumably corresponds to {@code java.sql.Types.VARCHAR}; the meaning
     * of 12345678 is defined by the consumer of this map - confirm against {@code SqlHelper}.
     *
     * @param columNames  column names
     * @param columnTypes column type names, positionally aligned with {@code columNames}
     * @return a new mutable map keyed by upper-cased column name
     * @throws IndexOutOfBoundsException if {@code columnTypes} is shorter than {@code columNames}
     */
    public static Map<String, Integer> buildColumnTypes(List<String> columNames,
                                                        List<String> columnTypes) {
        int length = columNames.size();
        Map<String, Integer> typeMaps = new HashMap<>(Math.max(16, length * 2));
        for (int i = 0; i < length; ++i) {
            // NOTE(review): the key keeps default-locale upper-casing because the way the
            // consumer of this map normalizes its lookups is not visible here.
            String key = columNames.get(i).toUpperCase();
            // Type names are compared against ASCII constants, so lower-case them
            // locale-independently ("INTEGER" must not become a dotless-i variant under tr_TR).
            String type = columnTypes.get(i).toLowerCase(Locale.ROOT);
            boolean textLike = charTypes.contains(type) || dateTypes.contains(type);
            typeMaps.put(key, textLike ? 12 : 12345678);
        }
        return typeMaps;
    }

    /**
     * Tells whether {@code key} is a character type or a date/time type.
     * The lookup is exact: no trimming or case conversion is applied.
     *
     * @param key type name to test
     * @return {@code true} when the name is in the character or the date/time set
     */
    public static boolean isCharType(String key) {
        return isStringType(key) || isDateType(key);
    }

    /**
     * Tells whether {@code key} is one of the integer type names (exact match).
     *
     * @param key type name to test
     * @return {@code true} when the name is in the integer set
     */
    public static boolean isIntType(String key) {
        return intTypes.contains(key);
    }

    /**
     * Tells whether {@code key} is one of the floating-point type names (exact match).
     *
     * @param key type name to test
     * @return {@code true} when the name is in the floating-point set
     */
    public static boolean isFloatType(String key) {
        return floatTypes.contains(key);
    }

    /**
     * Tells whether {@code key} is one of the date/time type names (exact match).
     *
     * @param key type name to test
     * @return {@code true} when the name is in the date/time set
     */
    public static boolean isDateType(String key) {
        return dateTypes.contains(key);
    }

    /**
     * Tells whether {@code key} is one of the character type names (exact match).
     *
     * @param key type name to test
     * @return {@code true} when the name is in the character set
     */
    public static boolean isStringType(String key) {
        return charTypes.contains(key);
    }

    /**
     * Checks that every listed column has a type belonging to the requested category.
     *
     * @param typeMap column name to type name
     * @param keys    the columns to check
     * @param type    one of {@link #NUMBER_TYPE}, {@link #STRING_TYPE} or {@link #DATE_TYPE};
     *                any other value (including {@code null}) imposes no constraint
     * @return {@code false} as soon as a column has a type outside the category or has no
     *         entry in {@code typeMap}; {@code true} otherwise
     */
    public static boolean checkTypeConsistence(Map<String, String> typeMap, List<String> keys,
                                               String type) {
        // Resolve the category once instead of re-dispatching on every column.
        final Set<String> allowed;
        if (NUMBER_TYPE.equals(type)) {
            allowed = numericTypes;
        } else if (STRING_TYPE.equals(type)) {
            allowed = charTypes;
        } else if (DATE_TYPE.equals(type)) {
            allowed = dateTypes;
        } else {
            // Unrecognized categories never rejected a column.
            return true;
        }
        for (String key : keys) {
            String value = typeMap.get(key);
            // A column without a recorded type cannot be confirmed as consistent; report
            // that rather than failing with a NullPointerException.
            if (value == null || !allowed.contains(value.toLowerCase(Locale.ROOT))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Selects the columns whose type belongs to the requested category.
     *
     * <p>For a recognized category the scan covers every index of {@code columnTypes}; for an
     * unrecognized (or {@code null}) category every index of {@code inputCols} is kept without a
     * type check. The {@link #ID_COL} column is always left out, except for
     * {@link #FLOAT_TYPE}, which never excluded it.
     *
     * @param inputCols   column names
     * @param columnTypes column type names, positionally aligned with {@code inputCols}
     * @param filterType  {@link #DATE_TYPE}, {@link #FLOAT_TYPE}, {@link #INT_TYPE},
     *                    {@link #STRING_TYPE}, {@link #NUMBER_TYPE}, or anything else for "no filter"
     * @return pair of (selected column names, their original type strings), in input order
     */
    public static Pair<List<String>, List<String>> filterTypeAndCol(List<String> inputCols,
                                                                    List<String> columnTypes, String filterType) {
        List<String> filterCols = new ArrayList<>();
        List<String> filterTypes = new ArrayList<>();

        Set<String> accepted = acceptedTypesFor(filterType);
        // NOTE(review): FLOAT is the only category that does not skip ID_COL; preserved as-is,
        // confirm whether that asymmetry is intentional.
        boolean skipIdCol = !FLOAT_TYPE.equals(filterType);
        int count = accepted == null ? inputCols.size() : columnTypes.size();

        for (int i = 0; i < count; ++i) {
            // Type names are matched against ASCII constants, hence Locale.ROOT.
            if (accepted != null
                    && !accepted.contains(columnTypes.get(i).toLowerCase(Locale.ROOT))) {
                continue;
            }
            String col = inputCols.get(i);
            if (skipIdCol && ID_COL.equals(col)) {
                continue;
            }
            filterCols.add(col);
            filterTypes.add(columnTypes.get(i));
        }
        return new Pair<>(filterCols, filterTypes);
    }

    /** Maps a category keyword to its set of type names; {@code null} means "no type restriction". */
    private static Set<String> acceptedTypesFor(String filterType) {
        if (DATE_TYPE.equals(filterType)) {
            return dateTypes;
        }
        if (FLOAT_TYPE.equals(filterType)) {
            return floatTypes;
        }
        if (INT_TYPE.equals(filterType)) {
            return intTypes;
        }
        if (STRING_TYPE.equals(filterType)) {
            return charTypes;
        }
        if (NUMBER_TYPE.equals(filterType)) {
            return numericTypes;
        }
        return null;
    }

    /**
     * Extracts the entry-point signature from a script: the text from {@code begin_alg}
     * through the closing parenthesis matched by the class-level {@code pattern}.
     *
     * @param script script source to search
     * @return the signature starting at {@code begin_alg}, or an empty string when the
     *         pattern does not match
     */
    public static String extractEntryFunction(String script) {
        Matcher matcher = pattern.matcher(script);
        if (!matcher.find()) {
            return "";
        }
        String definition = matcher.group(2);
        int nameStart = definition.indexOf("begin_alg");
        return nameStart == -1 ? "" : definition.substring(nameStart);
    }

    /**
     * Removes, in place, every key whose last dot-separated segment starts with
     * {@code prefix}. Double quotes are stripped from {@code prefix} before comparing.
     *
     * <p>Segments are produced by {@code String.split("\\.")}, so trailing empty segments are
     * dropped ({@code "a."} is judged by {@code "a"}). A key consisting only of dots has no
     * segments and is compared as an empty string.
     *
     * @param keys   mutable list of keys to filter
     * @param prefix prefix to reject; may be wrapped in double quotes
     */
    public static void startsWithFilter(List<String> keys, String prefix) {
        // Literal replacement; no regex needed to drop the quote characters.
        final String barePrefix = prefix.replace("\"", "");
        keys.removeIf(key -> {
            String[] segments = key.split("\\.");
            // split() yields an empty array for dot-only keys; guard the -1 index.
            String last = segments.length == 0 ? "" : segments[segments.length - 1];
            return last.startsWith(barePrefix);
        });
    }

    /**
     * Rewrites every element in place as {@code element#position}, where position is the
     * element's zero-based index in the list.
     *
     * @param list mutable list whose elements are replaced
     */
    public static void listWithIndex(List<String> list) {
        ListIterator<String> cursor = list.listIterator();
        while (cursor.hasNext()) {
            int position = cursor.nextIndex();
            cursor.set(cursor.next() + "#" + position);
        }
    }

    /**
     * Sorts the list in place using the natural ordering of {@code String}.
     *
     * @param list mutable list to sort
     */
    public static void sortList(List<String> list) {
        list.sort(Comparator.naturalOrder());
    }

    /**
     * Splits entries of the form {@code column#index} at the last {@code #} and appends the
     * column to {@code colsResult} and {@code refType.get(index)} to {@code typesResult}.
     * Entries without a {@code #} are ignored.
     *
     * @param list        entries to decode
     * @param colsResult  receives the column part of each decoded entry
     * @param refType     types addressed by the decoded index
     * @param typesResult receives the type looked up for each decoded entry
     * @throws NumberFormatException     if the text after the last {@code #} is not an integer
     * @throws IndexOutOfBoundsException if the decoded index is outside {@code refType}
     */
    public static void getColAndType(List<String> list, List<String> colsResult,
                                     List<String> refType, List<String> typesResult) {
        for (String entry : list) {
            int hashPos = entry.lastIndexOf('#');
            if (hashPos < 0) {
                continue;
            }
            colsResult.add(entry.substring(0, hashPos));
            int refIndex = Integer.parseInt(entry.substring(hashPos + 1));
            typesResult.add(refType.get(refIndex));
        }
    }

    /**
     * Normalizes text into a lower-case identifier-like string.
     *
     * <p>After trimming, each character is handled as follows: ASCII upper-case letters become
     * lower case; ASCII digits, ASCII lower-case letters and every character with a code above
     * 126 are kept; anything else is replaced by {@code separator}.
     *
     * @param text      text to normalize
     * @param separator replacement for characters that are not kept
     * @return the normalized string, same length as the trimmed input
     * @throws NullPointerException if {@code text} is {@code null}
     */
    public static String standardString(String text, char separator) {
        String trimmed = text.trim();
        // StringBuilder avoids re-copying the whole result on every appended character.
        StringBuilder res = new StringBuilder(trimmed.length());
        for (int i = 0; i < trimmed.length(); i++) {
            char ch = trimmed.charAt(i);
            if (ch >= 'A' && ch <= 'Z') {
                res.append((char) (ch + ('a' - 'A')));
            } else if ((ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'z') || ch > 126) {
                res.append(ch);
            } else {
                res.append(separator);
            }
        }
        return res.toString();
    }

    /**
     * Wraps each dot-separated part of a column reference in double quotes and joins the
     * parts with a dot again, e.g. {@code dataset.users} becomes {@code "dataset"."users"}.
     *
     * <p>Previously the parts were joined with the two-character text backslash-dot (a regex
     * escape mistakenly used as a separator), which corrupted any dotted name.
     *
     * @param column column reference, optionally qualified with dots
     * @return the quoted reference; {@code null} or an empty string is returned unchanged
     */
    public static String quoteWrapper(String column) {
        if (column == null || column.isEmpty()) {
            return column;
        }
        String[] parts = column.split("\\.");
        List<String> quoted = new ArrayList<>(parts.length);
        for (String part : parts) {
            quoted.add("\"" + part + "\"");
        }
        return String.join(".", quoted);
    }

    /**
     * Collects every substring of {@code text} matched by {@link #PY_FUNC_PATTERN}, in order of
     * appearance. Each result is the whole match (from {@code def} through the trailing colon).
     *
     * @param text source text to scan
     * @return the matches; empty when there are none
     */
    public static List<String> extractPythonFunctionName(String text) {
        List<String> found = new ArrayList<>();
        for (Matcher scan = Pattern.compile(PY_FUNC_PATTERN).matcher(text); scan.find(); ) {
            found.add(scan.group());
        }
        return found;
    }

    /**
     * Extracts the comma-separated parameter names between the first {@code (} and the last
     * {@code )} of a function signature.
     *
     * @param func signature text such as {@code begin_alg(a, b)}
     * @return the parameters with surrounding whitespace removed; an empty list for an empty
     *         parameter section; {@code null} when {@code func} is {@code null} or has no
     *         well-formed parenthesis pair (kept for callers that test for {@code null})
     */
    public static List<String> extractParamsForFunc(String func) {
        if (func == null) {
            return null;
        }
        int start = func.indexOf('(');
        int end = func.lastIndexOf(')');
        // end < start also covers end == -1 and inputs like ")(" that used to make
        // substring() throw.
        if (start == -1 || end < start) {
            return null;
        }
        String inner = func.substring(start + 1, end).trim();
        if (inner.isEmpty()) {
            // "f()" has no parameters; do not report a single empty-named one.
            return new ArrayList<>();
        }
        return new ArrayList<>(Arrays.asList(inner.split("\\s*,\\s*")));
    }

    /**
     * For each element of {@code input}, reads its {@code "tableCols"} and
     * {@code "columnTypes"} arrays, turns them into a type map via
     * {@code buildColumnTypes}, binds that map to a new {@code SqlHelper} and appends the
     * helper to {@code sqlHelpers}.
     *
     * @param input      array of objects carrying {@code tableCols} and {@code columnTypes}
     * @param sqlHelpers receives one bound helper per input element, in input order
     */
    public static void getSqlHelpers(JSONArray input, List<SqlHelper> sqlHelpers) {
        for (int idx = 0; idx < input.size(); idx++) {
            SqlHelper helper = new SqlHelper();
            List<String> cols = input.getJSONObject(idx)
                    .getJSONArray("tableCols")
                    .toJavaList(String.class);
            List<String> types = input.getJSONObject(idx)
                    .getJSONArray("columnTypes")
                    .toJavaList(String.class);
            helper.bind(buildColumnTypes(cols, types));
            sqlHelpers.add(helper);
        }
    }


}
